#!/usr/bin/python

import code,pickle,sys,os,re
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from pylab import *
from optparse import OptionParser
import ocrolib
from ocrolib import dbtables,quant

# Command-line interface: the usage text names two positional arguments
# (chars.db and output.db); the options below are all optional.
parser = OptionParser(usage="""
usage: %prog [options] chars.db output.db

""")

parser.add_option("-D", "--display", action="store_true",
                  help="display chars")
parser.add_option("-v", "--verbose", action="store_true",
                  help="verbose output")
parser.add_option("-t", "--table", default="chars",
                  help="table name")
parser.add_option("-e", "--epsilon", type="float", default=0.1,
                  help="epsilon")
parser.add_option("-o", "--overwrite", action="store_true",
                  help="overwrite output if it exists")

class FastCluster:
    """Single-pass clustering of character images with per-cluster label counts.

    Every image passed to `add` is scaled to unit L2 norm, converted to a
    feature vector by `ocrolib.ScaledFE` and looked up in an
    `ocrolib.EdistComp` index.  When the lookup reports no match (a negative
    index) the vector starts a new cluster; otherwise it is merged into the
    cluster that was found.
    """

    def __init__(self, eps=0.05):
        """Create an empty clusterer; `eps` is passed to every `EdistComp.find` call."""
        self.eps = eps
        self.ex = ocrolib.ScaledFE()
        self.dc = ocrolib.EdistComp()
        self.classes = []   # per cluster: {label: number of members with that label}
        self.counts = []    # per cluster: number of members
        self.total = 0      # number of images passed to add()

    def add(self, c, cls=None):
        """Add image `c` with label `cls` and return the index of its cluster.

        `c` is left unmodified; normalization works on a copy.
        """
        self.total += 1
        norm = sqrt(sum(c**2))
        # Divide out of place: an in-place `c /= norm` would overwrite the
        # caller's array and cannot store the result in an integer array.
        # An all-zero image has no direction to normalize, so it is passed
        # through as is instead of being turned into NaNs.
        if norm > 0:
            c = c / norm
        v = self.ex.extract(c)
        i = self.dc.find(v, self.eps)
        if i < 0:
            # No existing cluster matched: start a new one.
            self.dc.add(v)
            self.classes.append({cls: 1})
            self.counts.append(1)
            return len(self.counts) - 1
        members = self.classes[i]
        members[cls] = members.get(cls, 0) + 1
        self.counts[i] += 1
        # Weight 1/n -- presumably a running-mean update of the centroid;
        # confirm against EdistComp.merge.
        self.dc.merge(i, v, 1.0 / self.counts[i])
        return i

    def biniter(self):
        """Yield `(index, vector, count, key)` per cluster; `key` is always ""."""
        for i in range(self.dc.length()):
            key = ""
            v = self.dc.vector(i)
            count = self.dc.counts(i)
            yield i, v, count, key

    def cls(self, i):
        """Return `(label, count)` for the most frequent label of cluster `i`.

        On a tie the label encountered first when iterating the dict wins.
        """
        return max(self.classes[i].items(), key=lambda item: item[1])

    def stats(self):
        """Return "<images added> <number of clusters>" as one string."""
        return "%s %s" % (self.total, self.dc.length())

    def save(self, file):
        """Write one row per cluster to a `dbtables.ClusterTable` at `file`.

        The `count` column holds the number of members carrying the dominant
        label, as returned by `cls`, not the cluster size.
        NOTE(review): the size reported by `biniter` is discarded -- confirm
        that this is the intended value for `count`.
        """
        table = dbtables.ClusterTable(file)
        table.create(image="blob", cls="text", count="integer", classes="text")
        table.converter("image", dbtables.SmallImage())
        for i, v, _size, _key in self.biniter():
            peak = amax(v)
            # Stretch so the peak maps to 255; a vector whose peak is zero
            # cannot be rescaled and is stored as an all-zero image.
            scaled = v / peak * 255.0 if peak != 0 else v * 0.0
            image = array(scaled, 'B')
            # Vectors are stored as 30x30 images -- assumes ScaledFE yields
            # 900 values; TODO confirm.
            image.shape = (30, 30)
            cls, count = self.cls(i)
            classes = repr(self.classes[i])
            table.set(image=image, cls=cls, count=count, classes=classes)

options, args = parser.parse_args()

# Exactly two positional arguments are required; otherwise show the help
# text and stop.
if len(args) != 2:
    parser.print_help()
    sys.exit(0)

input, output = args

# An existing output file is removed only when --overwrite was given;
# without it the script reports the conflict and exits with status 1.
if os.path.exists(output):
    if options.overwrite:
        os.unlink(output)
    else:
        sys.stderr.write("%s: already exists\n" % output)
        sys.exit(1)

# Switch matplotlib to interactive mode before calling show().
ion()
show()

# Open the character table of the input database and register the converter
# used for its "image" column.
# NOTE(review): create() is called on the input table -- presumably a no-op
# when the table already exists; confirm against dbtables.Table.
table = dbtables.Table(input,options.table)
table.converter("image",dbtables.SmallImage())
table.create(image="blob",cluster="integer",cls="integer")

# Feed every character of the table through the clusterer in a single pass,
# reporting progress every 1000 characters, then write the clusters out.
binned = FastCluster(options.epsilon)
total = 0
for row in table.get():
    raw = row.image
    cls = row.cls
    # Skip images with a side longer than 255 pixels.
    if raw.shape[0]>255 or raw.shape[1]>255: continue
    # Skip blank images: with a maximum of zero the rescaling below would
    # divide by zero and hand NaN/inf pixels to the clusterer.
    peak = float(amax(raw))
    if peak<=0: continue
    binned.add(raw/peak,cls)
    total += 1
    if total%1000==0:
        print("# %s chars %s" % (total,binned.stats()))
print("# %s" % binned.stats())
binned.save(output)
